# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
# pylint: disable=no-name-in-module

import torch
import torch.nn as nn

from diffusion_bandit.dataset_generation import project_onto_sphere
from diffusion_bandit.neural_networks.shape_score_nets import GaussianFourierProjection


class LinearRewardModel(nn.Module):
    """Bias-free linear reward ``r(x) = w^T x`` with a learnable ``w``.

    The features the head acts on are the raw inputs themselves, so
    ``get_features`` is the identity.
    """

    def __init__(self, d_ext):
        super().__init__()
        # One output unit, no bias: the weight is a single (1, d_ext) row.
        self.layer = nn.Linear(in_features=d_ext, out_features=1, bias=False)

    def get_features(self, x_batch):
        """Return the features seen by the linear head (the input, unchanged)."""
        return x_batch

    def forward(self, x_batch):
        """Apply the linear head over the last dimension of ``x_batch``."""
        reward = self.layer(x_batch)
        return reward


class MLPRewardModel(nn.Module):
    """Bias-free linear reward head on top of a mode-dependent feature map.

    ``forward(x_batch, time)`` returns ``layer(features)``, where ``layer`` is
    ``nn.Linear(d_ext, 1, bias=False)`` and ``features`` depends on ``mode``:

    * ``"feedback"`` -- ``project_onto_sphere(x, projector, radius, surface)``.
    * ``"sampling"`` -- ``(x + std(t)**2 * score(t, x)) / mean_factor(t)``,
      built from ``diffusion_process.marginal_prob_std``,
      ``diffusion_process.marginal_prob_mean_factor`` and the ``score``
      network. This is presumably a Tweedie-style estimate of the clean
      sample ``x_0``; confirm against the diffusion process definition.
    * ``"simple"`` -- the input unchanged.
    """

    # Supported values for ``mode``. Checked by __init__, set_mode and
    # _compute_features.
    _VALID_MODES = ("feedback", "sampling", "simple")

    def __init__(
        self,
        d_ext,
        projector,
        radius,
        surface,
        score,
        diffusion_process,
        mode="feedback",  # or "sampling" / "simple"
    ):
        """Build the reward model.

        Args:
            d_ext (int): Input dimensionality. This is the size of the linear
                head's input and of the ``i`` buffer.
            projector (torch.Tensor | None): Registered as a buffer and passed
                to ``project_onto_sphere`` in "feedback" mode.
            radius: Passed to ``project_onto_sphere`` in "feedback" mode.
            surface: Passed to ``project_onto_sphere`` in "feedback" mode.
            score (nn.Module | None): Score network. It is called as
                ``score(time=..., x_batch=...)`` and is required for
                "sampling" mode.
            diffusion_process: Object exposing ``marginal_prob_std(time=...)``
                and ``marginal_prob_mean_factor(time=...)``. It is used only
                in "sampling" mode.
            mode (str): Initial feature mode: "feedback", "sampling" or
                "simple".

        Raises:
            ValueError: If ``mode`` is not a supported mode.
        """
        super().__init__()
        self._validate_mode(mode)
        self.d_ext = d_ext
        self.radius = radius
        # Registered as a buffer so it follows the module across device moves
        # and is saved in the state_dict.
        self.register_buffer("projector", projector)
        self.surface = surface  # forwarded to project_onto_sphere
        self.mode = mode
        self.diffusion_process = diffusion_process
        if score is not None:
            # Becomes a registered submodule. Note that .train() switches it
            # to training mode.
            self.score = score.train()
        else:
            # Plain attribute so "sampling" mode can report a clear error
            # instead of an AttributeError.
            self.score = None

        self.layer = nn.Linear(d_ext, 1, bias=False)
        # Row vector [[0, 1, ..., d_ext - 1]] registered as a buffer so it
        # moves with the model's device. It is not referenced elsewhere in
        # this class, but it is kept for state_dict compatibility.
        self.register_buffer("i", torch.arange(self.d_ext).view(1, -1))

    @classmethod
    def _validate_mode(cls, mode):
        """Raise ValueError unless ``mode`` is one of ``_VALID_MODES``."""
        if mode not in cls._VALID_MODES:
            valid = ", ".join(repr(name) for name in cls._VALID_MODES)
            raise ValueError(f"mode must be one of {valid}; got {mode!r}")

    def _compute_features(self, x_batch, time=None):
        """Compute the features fed to the linear head, according to ``self.mode``.

        Args:
            x_batch (torch.Tensor): Input points. The last dimension is
                assumed to be ``d_ext``.
            time: Diffusion time. It is only used in "sampling" mode, where
                it is passed to the diffusion process and the score network.
                Its shape must broadcast against ``x_batch`` (TODO confirm the
                expected shape with the diffusion process).

        Returns:
            torch.Tensor: The "feedback" projection, the "sampling" estimate
            ``(x + std**2 * score) / mean_factor``, or ``x_batch`` itself in
            "simple" mode.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode. This can
                happen when it is assigned directly instead of via set_mode.
            RuntimeError: In "sampling" mode when no score network was given.
        """
        # Validate first: an unknown mode used to fall through and
        # silently return None.
        self._validate_mode(self.mode)

        if self.mode == "feedback":
            return project_onto_sphere(
                x_data=x_batch,
                projector=self.projector,
                radius=self.radius,
                surface=self.surface,
            )

        if self.mode == "sampling":
            if self.score is None:
                raise RuntimeError(
                    "'sampling' mode requires a score network, but score=None was given"
                )
            var = self.diffusion_process.marginal_prob_std(time=time) ** 2
            mean = self.diffusion_process.marginal_prob_mean_factor(time=time)
            return (x_batch + var * self.score(time=time, x_batch=x_batch)) / mean

        # Only "simple" remains after validation.
        return x_batch

    def set_mode(self, mode):
        """Switch the feature map used by ``forward`` and ``get_features``.

        Args:
            mode (str): One of "feedback", "sampling" or "simple".

        Raises:
            ValueError: If ``mode`` is not recognized. The current mode is
                left unchanged.
        """
        self._validate_mode(mode)
        self.mode = mode

    def get_features(self, x_batch, time=None):
        """Public alias for ``_compute_features``."""
        return self._compute_features(x_batch=x_batch, time=time)

    def forward(self, x_batch, time=None):
        """Return the reward ``layer(features(x_batch, time))``.

        For 2-D input of shape (batch, d_ext) the result has shape (batch, 1).
        """
        features = self._compute_features(x_batch=x_batch, time=time)
        return self.layer(features)


def get_ground_truth_reward_model(
    d_ext, projector, radius, surface, name, score=None, diffusion_process=None
):
    """Construct a ground-truth reward model by name.

    Args:
        d_ext (int): Input dimensionality of the reward model.
        projector: Forwarded to ``MLPRewardModel`` (ignored for "linear").
        radius: Forwarded to ``MLPRewardModel`` (ignored for "linear").
        surface: Forwarded to ``MLPRewardModel`` (ignored for "linear").
        name (str): "linear" for ``LinearRewardModel``, "mlp" for
            ``MLPRewardModel``. The MLP model is built with its default mode.
        score: Score network forwarded to ``MLPRewardModel``.
        diffusion_process: Diffusion process forwarded to ``MLPRewardModel``.

    Returns:
        nn.Module: The constructed reward model.

    Raises:
        NotImplementedError: If ``name`` is not a known model name.
    """
    if name == "linear":
        return LinearRewardModel(d_ext)
    if name == "mlp":
        return MLPRewardModel(
            d_ext, projector, radius, surface, score, diffusion_process
        )
    raise NotImplementedError(f"No reward model with name {name}")


class FeatureNet(nn.Module):
    """Time-conditioned residual MLP producing a ``d_ext``-dimensional output.

    The input ``x`` is lifted to ``hidden_size`` by two
    Linear -> LayerNorm -> LeakyReLU stages. A LeakyReLU-activated
    Fourier embedding of ``time`` is added to it. The sum passes through
    ``num_layers`` pre-norm residual blocks (LayerNorm -> LeakyReLU -> Linear,
    plus skip connection) and is projected back to ``d_ext``.
    """

    def __init__(
        self,
        d_ext: int,
        hidden_size: int,
        num_layers: int,
    ):
        super().__init__()

        # Fourier features of the time input; activated in forward().
        self.time_embedding = GaussianFourierProjection(embed_dim=hidden_size)
        self.activation = nn.LeakyReLU()

        # Two lifting stages. The layers are created in the same order as
        # before, so parameter initialization consumes the RNG identically.
        self.input_layer = nn.Sequential(
            *self._linear_norm_act(d_ext, hidden_size),
            *self._linear_norm_act(hidden_size, hidden_size),
        )

        self.hidden_layers = nn.ModuleList(
            [self._pre_norm_block(hidden_size) for _ in range(num_layers)]
        )

        self.output_layer = nn.Linear(hidden_size, d_ext)

    @staticmethod
    def _linear_norm_act(in_features, out_features):
        """Return the [Linear, LayerNorm, LeakyReLU] layers of one lifting stage."""
        return [
            nn.Linear(in_features, out_features),
            nn.LayerNorm(out_features),
            nn.LeakyReLU(),
        ]

    @staticmethod
    def _pre_norm_block(width):
        """Return a LayerNorm -> LeakyReLU -> Linear block (skip added in forward)."""
        return nn.Sequential(
            nn.LayerNorm(width),
            nn.LeakyReLU(),
            nn.Linear(width, width),
        )

    def forward(self, time: torch.Tensor, x_batch: torch.Tensor) -> torch.Tensor:
        """Map ``(time, x_batch)`` to an output whose last dimension is ``d_ext``.

        Args:
            time: Time input for ``GaussianFourierProjection``. Its embedding
                must broadcast against the lifted ``x_batch`` (TODO confirm
                the expected shape, presumably (batch,)).
            x_batch: Input whose last dimension is ``d_ext``.
        """
        # Embed time first, matching the original evaluation order.
        time_features = self.activation(self.time_embedding(time))
        hidden = self.input_layer(x_batch) + time_features

        for block in self.hidden_layers:
            hidden = hidden + block(hidden)

        return self.output_layer(hidden)
